package org.zjvis.datascience.common.algo;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import org.apache.commons.lang3.StringUtils;
import org.zjvis.datascience.common.constant.SqlTemplate;
import org.zjvis.datascience.common.enums.AlgEnum;
import org.zjvis.datascience.common.enums.SubTypeEnum;
import org.zjvis.datascience.common.sql.SqlHelper;
import org.zjvis.datascience.common.util.ToolUtil;
import org.zjvis.datascience.common.vo.TaskVO;

import java.util.List;
import java.util.Locale;


/**
 * DBSCAN clustering operator template.
 *
 * <p>Builds the statement that runs DBSCAN for a pipeline task (MADlib SQL or a Spark command
 * line, depending on the configured engine), and describes the operator's template parameters
 * and output table.
 *
 * @date 2021-12-24
 */
public class DBscanAlg extends BaseAlg {

    /** Location of this operator's parameter template. */
    private static final String TPL_FILENAME = "template/algo/dbscan.json";

    /**
     * MADlib call over the full source table. Placeholders: schema, source table, output table,
     * id column, feature columns, eps ({@code %f}), min samples ({@code %d}), metric.
     */
    private static final String DBSCAN_SQL_MADLIB =
            "select * from \"%s\".\"dbscan\"('%s', '%s', '%s', '%s', %f, %d, '%s')";

    /**
     * MADlib call over a sample. Its first argument is a {@code CREATE VIEW} statement selecting
     * the source rows whose id column is {@code <=} the sample bound; the remaining placeholders
     * mirror {@link #DBSCAN_SQL_MADLIB}.
     */
    private static final String DBSCAN_SQL_MADLIB_SAMPLE =
            "select * from \"%s\".\"dbscan\"('CREATE VIEW %s AS SELECT * from %s where \"%s\" <= %s', '%s', '%s', '%s', %f, %d, '%s')";

    /** Spark job command line; every placeholder is rendered with {@code %s}. */
    private static final String DBSCAN_SQL_SPARK =
            "dbscan -s %s -f %s -t %s -eps %s -minp %s -maxp %s -idcol %s -uk %s";

    public DBscanAlg() {
        super(AlgEnum.DBSCAN.name(), SubTypeEnum.CLUSTER.getVal(), SubTypeEnum.CLUSTER.getDesc());
        this.maxParentNumber = 1;
    }

    /**
     * Builds the statement that runs DBSCAN.
     *
     * <p>When {@code sampleTable} is non-empty, the MADlib sampling statement is returned
     * regardless of the configured engine. Otherwise the configured engine decides between MADlib
     * SQL and a Spark command line. Values are formatted with {@link Locale#ROOT}, so the decimal
     * separator is always {@code '.'} whatever the JVM's default locale is. With a default locale
     * that uses a decimal comma, {@code %f} would otherwise inject an extra SQL argument.
     *
     * @param sourceTable input table name
     * @param outTable output table name
     * @param featureCols comma-separated feature column names
     * @param eps the {@code eps} value passed to DBSCAN
     * @param minSamples min-samples value (MADlib statements only)
     * @param metric metric name (MADlib statements only)
     * @param timeStamp value passed as {@code -uk} (Spark only)
     * @param minPoints value passed as {@code -minp} (Spark only)
     * @param maxPointsPerPartition value passed as {@code -maxp} (Spark only)
     * @param sampleTable name of the sample view to create, or empty for a full run
     * @return the statement; an empty string when {@code featureCols} is empty or the engine is
     *     neither MADlib nor Spark
     */
    public String getDbscanSql(String sourceTable, String outTable, String featureCols, float eps,
                               int minSamples, String metric, long timeStamp, int minPoints, int maxPointsPerPartition,
                               String sampleTable) {
        if (StringUtils.isEmpty(featureCols)) {
            return StringUtils.EMPTY;
        }
        if (StringUtils.isNotEmpty(sampleTable)) {
            // Sampled run: always executed on the MADlib engine.
            return String.format(Locale.ROOT, DBSCAN_SQL_MADLIB_SAMPLE, SqlTemplate.SCHEMA,
                    sampleTable, sourceTable, ID_COL, SAMPLE_NUMBER, outTable, ID_COL, featureCols,
                    eps, minSamples, metric);
        }
        // Full run: engine selected by configuration.
        if (getEngine().isMadlib()) {
            return String.format(Locale.ROOT, DBSCAN_SQL_MADLIB, SqlTemplate.SCHEMA, sourceTable,
                    outTable, ID_COL, featureCols, eps, minSamples, metric);
        }
        if (getEngine().isSpark()) {
            return String.format(Locale.ROOT, DBSCAN_SQL_SPARK, sourceTable, featureCols, outTable,
                    eps, minPoints, maxPointsPerPartition, ID_COL, timeStamp);
        }
        return StringUtils.EMPTY;
    }

    /**
     * Reads the operator parameters from {@code json} and builds the DBSCAN statement.
     *
     * <p>Side effect: when {@code isSample} is absent, null, {@code "SUCCESS"} or {@code "FAIL"},
     * a sample view name is derived from the output table ({@code solid_} replaced by
     * {@code view_}), and {@code isSample} is set to {@code "CREATE"} in {@code json}.
     *
     * @param json operator parameters. Keys read: source_table, out_table_rename, feature_cols,
     *     eps, min_samples, metric, min_points, max_points_per_partition, isSample
     * @param sqlHelpers not used by this operator
     * @param timeStamp used to align table names and passed on to {@link #getDbscanSql}
     * @param engineName engine name stored on this instance before the statement is built
     * @return the statement, or {@code null} when {@code feature_cols} is missing or empty
     */
    public String initSql(JSONObject json, List<SqlHelper> sqlHelpers, long timeStamp,
                          String engineName) {
        this.engineName = engineName;
        String sourceTable = ToolUtil.alignTableName(json.getString("source_table"), timeStamp);
        String outTable = ToolUtil.alignTableName(json.getString("out_table_rename"), timeStamp);

        JSONArray features = json.getJSONArray("feature_cols");
        if (features == null || features.isEmpty()) {
            return null;
        }
        String[] featureNames = new String[features.size()];
        for (int i = 0; i < featureNames.length; ++i) {
            featureNames[i] = toDbscanFeatureName(features.getString(i));
        }
        String featureCols = String.join(",", featureNames);

        float eps = json.getFloat("eps");
        int minSamples = json.getInteger("min_samples");
        String metric = json.getString("metric");
        int minPoints = json.getInteger("min_points");
        int maxPointsPerPartition = json.getInteger("max_points_per_partition");

        String sampleTable = "";
        // getString returns null both for an absent key and for an explicit null value; the
        // constant-first equals calls keep a null state from throwing.
        String sampleState = json.getString("isSample");
        if (sampleState == null || "SUCCESS".equals(sampleState) || "FAIL".equals(sampleState)) {
            sampleTable = outTable.replace("solid_", "view_");
            json.put("isSample", "CREATE");
        }
        return getDbscanSql(sourceTable, outTable, featureCols, eps, minSamples,
                metric, timeStamp, minPoints, maxPointsPerPartition, sampleTable);
    }

    /**
     * Strips a dotted qualifier from a column name (e.g. {@code t.col} becomes {@code col}).
     * The last segment produced by {@code split("\\.")} is kept. A name with no non-empty segment
     * is returned unchanged.
     */
    private static String toDbscanFeatureName(String feature) {
        if (!feature.contains(".")) {
            return feature;
        }
        String[] parts = feature.split("\\.");
        return parts.length == 0 ? feature : parts[parts.length - 1];
    }

    /**
     * Fills {@code data} with this operator's template description. It sets {@code setParams}
     * from the template file, then the base template fields, then {@code outputCols} and the
     * {@code validate} rules.
     *
     * @param data object to populate (mutated in place)
     */
    public void initTemplate(JSONObject data) {
        data.put("setParams", getTemplateParamList(TPL_FILENAME));
        baseInitTemplate(data);

        JSONArray outputCols = new JSONArray();
        outputCols.add("id");
        outputCols.add("label");
        data.put("outputCols", outputCols);

        JSONArray validate = new JSONArray();
        // NOTE(review): presumably "feature_cols must be numeric columns" — confirm rule syntax.
        validate.add("feature_cols,number");
        data.put("validate", validate);
    }

    /**
     * Derives the task's output table name and output schema, and writes them into the task's
     * data.
     *
     * <p>{@code out_table_rename} is always set. If the task has no input, nothing else happens.
     * Otherwise:
     * <ul>
     *   <li>{@code source_table} is taken from the first input;</li>
     *   <li>the feature-column checkbox helpers are applied;</li>
     *   <li>{@code output} is set to a single table holding the id column and {@code cluster_id}.</li>
     * </ul>
     *
     * @param vo task whose data is read and updated
     */
    public void defineOutput(TaskVO vo) {
        JSONObject jsonObject = vo.getData();
        String outTablePrefix = jsonObject.getString("out_table");
        String tableName = String.format(SqlTemplate.OUT_TABLE_NAME, outTablePrefix,
                vo.getPipelineId(), vo.getId());
        jsonObject.put("out_table_rename", tableName);

        JSONArray input = jsonObject.getJSONArray("input");
        if (input == null || input.isEmpty()) {
            return;
        }
        jsonObject.put("source_table", input.getJSONObject(0).getString("tableName"));
        this.checkBoxSelectFilter(jsonObject, "number", FEATURE_COLS);
        // NOTE(review): the meaning of 4 is defined by BaseAlg.supplementForCheckbox — confirm.
        this.supplementForCheckbox(jsonObject, TPL_FILENAME, 4, vo);

        JSONArray outputCols = new JSONArray();
        outputCols.add(ID_COL);
        outputCols.add("cluster_id");
        JSONArray outputColumnTypes = new JSONArray();
        outputColumnTypes.add(ID_TYPE);
        outputColumnTypes.add("text");

        JSONObject item = new JSONObject();
        item.put("tableName", tableName);
        item.put("tableCols", outputCols);
        item.put("nodeName", vo.getName() == null ? "DBSCAN" : vo.getName());
        item.put("columnTypes", outputColumnTypes);
        this.setSubTypeForOutput(item);

        JSONArray output = new JSONArray();
        output.add(item);
        jsonObject.put("output", output);
        vo.setData(jsonObject);
    }
}
